{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Colab-RIFE.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zY486A6vUokn"
      },
      "source": [
        "# Colab-RIFE\n",
        "\n",
        "Original repo: [hzwer/arXiv2020-RIFE](https://github.com/hzwer/arXiv2020-RIFE)\n",
        "\n",
        "My fork: [styler00dollar/Colab-RIFE](https://github.com/styler00dollar/Colab-RIFE)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SQnkDEw3Av92",
        "cellView": "form"
      },
      "source": [
        "#@title Check GPU\n",
        "!nvidia-smi"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K97x7-ag5FtA",
        "cellView": "form"
      },
      "source": [
        "#@title Setup (+ v1.2 model)\n",
        "!git clone https://github.com/hzwer/Arxiv2020-RIFE\n",
        "!mkdir /content/Arxiv2020-RIFE/input\n",
        "!mkdir /content/Arxiv2020-RIFE/output\n",
        "!mkdir /content/Arxiv2020-RIFE/train_log\n",
        "\n",
        "!sudo apt install ffmpeg\n",
        "\n",
        "!pip install gdown\n",
        "%cd /content/Arxiv2020-RIFE/train_log\n",
        "# fist model\n",
        "#!gdown --id 1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo\n",
        "#!7z e RIFE_trained_model.zip\n",
        "\n",
        "# 2020.11.19, v1.2\n",
        "!gdown --id 1zYc3PEN4t6GOUoVYJjvcXoMmM3kFDNGS\n",
        "!7z e RIFE_trained_model_new.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8KmTzccN47mL",
        "cellView": "form"
      },
      "source": [
        "#@title test.py (Original by [Jacob_](https://github.com/buildist) with some modifications)\n",
        "# Original: https://pastebin.com/raw/XzZgeqbs\n",
        "%%writefile /content/Arxiv2020-RIFE/test.py\n",
        "import cv2\n",
        "import os\n",
        "import numpy as np\n",
        "import shutil\n",
        "import torch\n",
        "import torchvision\n",
        "from torchvision import transforms\n",
        "from torch.nn import functional as F\n",
        "from PIL import Image\n",
        "from model.RIFE import Model\n",
        "from glob import glob\n",
        "from imageio import imread, imsave\n",
        "from torch.autograd import Variable\n",
        "from tqdm import tqdm\n",
        "video_paths = ['/content/Arxiv2020-RIFE/']\n",
        "video_paths_ = glob(video_paths[0] + '/**/*.mp4', recursive=True)\n",
        "\n",
        "\n",
        "for path in video_paths_:\n",
        "    name = os.path.basename(path)\n",
        "    print(name)\n",
        "    if os.path.isfile(path):\n",
        "        output_path = path.replace('input', 'tmp')\n",
        "        os.makedirs(output_path, exist_ok = True)\n",
        "        video_paths.append(output_path)\n",
        "        if not os.path.isfile('{:s}/00000001.png'.format(output_path)):\n",
        "            os.system('ffmpeg -i {:s} {:s}/%08d.png -hide_banner'.format(path, output_path));\n",
        "\n",
        "RIFE_model = Model()\n",
        "RIFE_model.load_model('./train_log')\n",
        "RIFE_model.eval()\n",
        "RIFE_model.device()\n",
        "\n",
        "for path in video_paths_:\n",
        "    name = os.path.basename(path)\n",
        "    length = len(glob(path + '/*.png'))\n",
        "    sr_output_path = path.replace(name, name+'_sr_out')\n",
        "    os.makedirs(sr_output_path, exist_ok = True)\n",
        "    interp_output_path = path.replace(name, name+'_interp_out')\n",
        "    os.makedirs(interp_output_path, exist_ok = True)\n",
        "    output_path = path.replace('tmp', 'output')\n",
        "    if os.path.isfile(output_path):\n",
        "        continue\n",
        "\n",
        "    with torch.no_grad():\n",
        "        if not os.path.isfile('{:s}/00000001.png'.format(interp_output_path)):\n",
        "            output_frame_number = 1\n",
        "            for input_frame_number in tqdm(range(1, length)):\n",
        "                frame_0_path = 'tmp/'+'{:s}/{:08d}.png'.format(name, input_frame_number)\n",
        "                frame_1_path = 'tmp/'+'{:s}/{:08d}.png'.format(name, input_frame_number + 1)\n",
        "\n",
        "                frame_0 = cv2.imread(frame_0_path)\n",
        "                frame_1 = cv2.imread(frame_1_path)\n",
        "\n",
        "                h, w, _ = frame_0.shape\n",
        "                ph = h // 32 * 32+32\n",
        "                pw = w // 32 * 32+32\n",
        "                padding = (0, pw - w, 0, ph - h)\n",
        "                frame_0 = torch.tensor(frame_0.transpose(2, 0, 1)).cuda() / 255.\n",
        "                frame_1 = torch.tensor(frame_1.transpose(2, 0, 1)).cuda() / 255.\n",
        "\n",
        "                res = RIFE_model.inference(frame_0.unsqueeze(0), frame_1.unsqueeze(0)) * 255\n",
        "\n",
        "\n",
        "                shutil.copyfile(frame_0_path, '{:s}/{:08d}.png'.format(interp_output_path, output_frame_number))\n",
        "                output_frame_number += 1\n",
        "                cv2.imwrite('{:s}/{:08d}.png'.format(interp_output_path, output_frame_number), res[0].cpu().numpy().transpose(1, 2, 0)[:h, :w])\n",
        "                output_frame_number += 1\n",
        "\n",
        "                if output_frame_number == length*2 - 1:\n",
        "                    shutil.copyfile(frame_1_path, '{:s}/{:08d}.png'.format(interp_output_path, output_frame_number))\n",
        "                    output_frame_number += 1\n",
        "    os.makedirs(output_path.replace(os.path.basename(output_path), ''), exist_ok = True)\n",
        "    source_path = path.replace('tmp', 'input')\n",
        "    #os.system('ffmpeg -y -f image2 -r 30 -i {:s}/%08d.png -i {:s} -c copy -map 0:0 -map 1:1? -crf 3 -c:v mjpeg -q:v 3 {:s} -hide_banner'.format(interp_output_path, source_path, output_path));\n",
        "    os.system('ffmpeg -y -f image2 -r 47.952 -i {:s}/%08d.png -i {:s} -c copy -map 0:0 -map 1:1? -crf 3 -c:v libx264 -q:v 8 {:s} -hide_banner'.format(interp_output_path, source_path, output_path));"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aS6Kwj09T1DV"
      },
      "source": [
        "Place ```.mp4``` videos in ```/input``` and result will be in ```/output```. Run it twice in a row.\n",
        "\n",
        "If you want to change framerate or ffmpeg settings, do that in ```test.py```. \n",
        "\n",
        "Don't forget to input 2x FPS in the ffmpeg command."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M2hca_QyHcxs",
        "cellView": "form"
      },
      "source": [
        "#@title Interpolate\n",
        "%cd /content/Arxiv2020-RIFE\n",
        "!python test.py"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hexyye_NUJ7p",
        "cellView": "form"
      },
      "source": [
        "#@title Delete all files\n",
        "%cd /content/\n",
        "!sudo rm -rf /content/Arxiv2020-RIFE/input\n",
        "!sudo rm -rf /content/Arxiv2020-RIFE/output\n",
        "!sudo rm -rf /content/Arxiv2020-RIFE/tmp\n",
        "\n",
        "!mkdir /content/Arxiv2020-RIFE/input\n",
        "!mkdir /content/Arxiv2020-RIFE/output\n",
        "!mkdir /content/Arxiv2020-RIFE/tmp"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
